import taichi as ti

from ..shape import Shape
from ..utils import Vec3, Inf, Ray3
from ..config import EPS_CUBOID_FACE


class Cuboid(Shape):
    """Axis-aligned cuboid (rectangular box) centred on the origin.

    The box spans ``[-w/2, w/2] x [-h/2, h/2] x [-d/2, d/2]`` in the frame of
    the rays handed to :meth:`intersect`. Its half-extents are stored in
    ``Shape.shapes[self.index].attribute[0:3]``.
    """

    def __init__(self, w: float = 1, h: float = None, d: float = None):
        """Register a cuboid with the given edge lengths.

        Args:
            w: Edge length along x.
            h: Edge length along y; falls back to ``w`` when omitted.
            d: Edge length along z; falls back to ``w`` when omitted.

        With only ``w`` given, the result is a cube of edge ``w``.
        """
        super().__init__()
        # Each omitted dimension falls back to ``w`` independently, so a
        # partially specified box such as Cuboid(2, 3) is valid (2 x 3 x 2)
        # instead of dividing None by 2 below.
        if h is None:
            h = w
        if d is None:
            d = w
        # Store half-extents; intersect() uses them as the slab bounds.
        Shape.shapes[self.index].attribute[0] = w / 2
        Shape.shapes[self.index].attribute[1] = h / 2
        Shape.shapes[self.index].attribute[2] = d / 2

    @ti.func
    def intersect(self: int, ray: Ray3, closest: float):
        """Intersect ``ray`` with this cuboid using the slab method.

        Args:
            self: Index of this shape in ``Shape.shapes``. Inside Taichi
                scope the shape is addressed by index, not by instance.
            ray: Ray to test. Presumably already expressed in the cuboid's
                frame (the box is taken as centred on the origin) -- TODO
                confirm against the caller.
            closest: Largest hit distance that will be accepted.

        Returns:
            ``(dist, norm)``. ``dist`` is the distance along the ray to the
            hit, within ``[EPS_CUBOID_FACE, closest]``, and ``norm`` is the
            unit surface normal there. Returns ``(Inf, Vec3(0))`` when there
            is no such hit. If the ray origin is inside the box, the exit
            point is reported.
        """
        whd_2 = Shape.shapes[self].attribute[:3]
        # Ray parameters at which the ray crosses the lower/upper plane of
        # each axis-aligned slab.
        t0 = (-whd_2 - ray.origin) / ray.direct
        t1 = (whd_2 - ray.origin) / ray.direct
        # Entry = the last slab entered; exit = the first slab left.
        t_enter, t_exit = min(t0, t1).max(), max(t0, t1).min()
        dist, norm = Inf, Vec3(0)
        # The ray overlaps the box at some point in front of the origin.
        if t_exit >= ti.max(t_enter, EPS_CUBOID_FACE):
            # Use the entry point when it lies ahead of the origin. Otherwise
            # the origin is inside (or on) the box, so use the exit point.
            hit = t_enter if t_enter > EPS_CUBOID_FACE else t_exit
            # Bound the distance actually returned by ``closest``, in both the
            # outside and inside cases.
            if hit <= closest:
                dist = hit
                # Hit point scaled so that the box becomes [-1, 1]^3.
                point = ray.at(dist) / whd_2
                # The face that was hit is the axis whose |coordinate| is
                # closest to 1. Keep only that component (all tied components
                # at edges/corners) and normalise to get the outward normal.
                eps = abs(abs(point) - 1)
                norm = point * (eps == eps.min())
                norm = norm.normalized()
        return dist, norm
